
Yet another batch normalization PR #3229

Merged 2 commits into BVLC:master on Oct 23, 2015

Conversation

cdoersch
Contributor

This PR squashes together #1965 and #3161 to make sure that proper credit is given. The final functionality is much more like #3161: we ultimately decided that the scale/shift could be implemented as a separate layer (and should hence get its own PR), and that the data shuffling, if it gets merged, should also be done as a separate PR (I have not reviewed that code closely enough to say whether it is mergeable). This version includes the global stats computations and fixes the issue where #3161 was using the biased variance estimate (it took a little while to convince myself that the unbiased one is indeed the correct estimator to use).

It would be great if @ducha-aiki and @jeffdonahue could take a look at this.

}
}

layer {
Contributor

This prototxt does not work, because the params were taken from #1965.
A corrected one is here: https://gist.github.com/ducha-aiki/6457bbd49fea8b7634a7

Contributor Author

Good catch; that totally slipped my mind.

@ducha-aiki
Contributor

@cdoersch Looks good to me apart from the things I have commented on.

if (bottom[0]->num_axes() == 1)
  channels_ = 1;
else
  channels_ = bottom[0]->channels();
Contributor

bottom[0]->shape(1) for ND compatibility. (Ideally there'd be an axis parameter with default 1 to specify which axis the stats are computed along, but that can always come later, and a Reshape will work, just less conveniently.)

caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(),
    moving_average_fraction_, this->blobs_[0]->mutable_cpu_data());
Dtype m = Dtype(bottom[0]->count()/channels_);
caffe_cpu_axpby(variance_.count(), m/(m-1), variance_.cpu_data(),
Contributor

I wonder if the m in the unbiased variance correction scalar m/(m-1) should be the "batch size" rather than the total number of samples, since the "pixels" within a given batch item are not going to be IID?

edit: Looking at the batch norm paper, this seems to be what they do -- in the algorithm box at the bottom of pg 4, step 10 has the unbiased variance computation and uses m/(m-1) where m is the batch size.

Contributor Author

So the paper is actually rather unclear about this. Right above the algorithm box, it says "We have m values of this activation in the mini-batch", which doesn't make it sound like it's num_width_height.

I also thought we had some evidence that, at least as far as second-order statistics go, the features within a channel are quite uncorrelated with each other, which would suggest num/(num-1) is too severe a correction. But I'm not sure. Is there a reference implementation anywhere?

Contributor

Hm, reading further I think what you have here is indeed what the paper does:

In Alg. 1, we let B be the set of all values in a feature map across both the elements of a mini-batch and spatial locations – so for a mini-batch of size m and feature maps of size p × q, we use the effective mini-batch of size m′ = |B| = m · p · q.

Intuitively it seems strange to me -- I would think that for many types of filters, neighboring activations within a feature map would be highly correlated. I don't recall seeing evidence to the contrary (or supporting me, this is just my intuition) but would be interested if you or anyone has a reference.

Regardless, I'll retract my suggested change as it seems like this is in fact what the batch norm paper does. Perhaps it could later be made configurable if the alternative is ever found to be useful.
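For reference, a minimal NumPy sketch (my own illustration, not code from this PR) of the correction being discussed, using the effective mini-batch size m′ = N·H·W:

import numpy as np

# Hypothetical activations: batch N, channels C, spatial H x W.
N, C, H, W = 8, 16, 12, 12
x = np.random.randn(N, C, H, W)

m = N * H * W                           # effective mini-batch size m' = |B| = N*H*W
batch_var = x.var(axis=(0, 2, 3))       # biased (divide-by-m) per-channel variance
unbiased_var = batch_var * m / (m - 1)  # correction applied when accumulating the running variance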


@jeffdonahue:

I would think that for many types of filters, neighboring activations within a feature map would be highly correlated. I don't recall seeing evidence to the contrary (or supporting me, this is just my intuition) but would be interested if you or anyone has a reference.

I was also looking for studies of how correlated (how dependent, actually) neighboring results of convolutions (and activations) are within a feature map. My intuition also tells me they should be highly dependent.

Contributor Author

I'll admit that I originally thought they would be highly correlated too. However, at one point I was trying to figure out whether I needed to do whitening on conv5 features extracted from patches before plugging them into some classifier, and so I actually measured it. I don't remember the actual numbers, but I remember being amazed at how small the correlations were. Even for neighboring units within the same channel, the correlations were something like R=.05 or .1. Somebody should repeat the experiment to make sure of this, though; this was a year and a half ago, so my memory is pretty faded.
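For anyone who wants to repeat that measurement, a rough sketch (my own, assuming the features have already been extracted into an (N, C, H, W) array called feats) could look like:

import numpy as np

def neighbor_correlation(feats):
    # feats: (N, C, H, W) activations from some conv layer (hypothetical input).
    # Pearson correlation between each unit and its right-hand neighbour,
    # pooled over the batch and spatial positions, reported per channel.
    left = feats[:, :, :, :-1]
    right = feats[:, :, :, 1:]
    corrs = []
    for c in range(feats.shape[1]):
        corrs.append(np.corrcoef(left[:, c].ravel(), right[:, c].ravel())[0, 1])
    return np.array(corrs)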

Member

When I looked through layers while learning and checking fully convolutional nets I remember seeing results like this too. At some point it'd be nice to go back and investigate the spatial correlations and effective receptive field sizes of different layers.


It is not surprising that the neighboring units have very little or no correlation. Stacked CNN layers actually behave more like very strong matched filters, which have the property of strongly spatially-localized responses. This strong spatial localization is the basis of the ability to regress the object location with fully-connected layers using CNN feature maps as the input. Also, as we move up the layer stack, a one-pixel difference starts corresponding to a difference of multiple pixels in the original image, so small spatial correlation is naturally expected.
Note: this comment is based on the observation of some feature maps, regression-based object localization performance, and some intuition developed along the way while working with deep CNNs.


Intuition kinda says that in the first convolutional layer neighboring units are highly dependent for real (non-random) filters and real images (if for no other reason than that neighboring image pixels are highly dependent, as are filter weights). As you move up the layer stack, both correlation and dependency should and do decrease (corresponding to the "understanding is a compressed representation" paradigm), although units will still be dependent for homogeneous areas of the image within the receptive field.

This decrease in correlation/dependency may well be critical (who knows) for FC to work and, if measured precisely enough, may well tell us when the right time is to transition to an FC layer. Although the key for the transition seems to be sparsity (causing independence as a side effect), so independence without sparsity probably has comparatively little value anyway.

@jeffdonahue
Contributor

Thanks for putting this PR together @cdoersch and for the early work @ducha-aiki! Besides the above comment, this looks good to me. I haven't tried the examples, but if they really don't converge I'm not sure yet how we should handle that... they probably shouldn't be merged as is; broken examples aren't great for most users... Perhaps they should temporarily use the 1x1 convolution hack until we have the dedicated layers merged as well? (ParameterLayer (#2079) and ScalarLayer (#3021) would handle the scaling parameter; a BiasLayer should be PRed at some point...)

@ducha-aiki
Contributor

@jeffdonahue they successfully converge and actually work (when using the fixed example). What is not working is in-place computation, which I propose for now just to flag with
CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not allow in-place computation.";

@shelhamer
Member

@cdoersch @ducha-aiki my experiments with in-place computation of this layer do not converge either, although I have had convergence with this version of batch norm derived from #1965: https://github.com/HyeonwooNoh/caffe/blob/master/src/caffe/layers/bn_layer.cu. It could be worth a glance before merge.

@jeffdonahue
Contributor

Ah, I see, thanks for the clarification/reiteration @ducha-aiki. I agree with your suggestion in that case -- there should either be such a check (and ideally the then unnecessary variance caching would be removed), or we should try to fix in-place computation.

  lr_mult: 2
}
convolution_param {
  num_output: 32
Contributor

We should have bias_term: false in the Convolution layers before BatchNorms, since the effect is cancelled out by mean subtraction, no?

Contributor Author

Good point -- wasted computation.

Contributor

Thanks for updating -- the second param also needs to be removed though (there's a CHECK failure if not), and preferably the bias_filler should be removed too (though it has no effect).
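As a sketch of what the cleaned-up pair would look like (written with pycaffe's NetSpec rather than the raw prototxt; the conv_bn helper and its filler settings are illustrative, not taken from this PR):

from caffe import layers as L

def conv_bn(bottom, nout):
    # Convolution without a bias term: BatchNorm's mean subtraction would
    # cancel any bias anyway, so the parameter is pure waste.
    conv = L.Convolution(bottom, num_output=nout, kernel_size=3, pad=1,
                         bias_term=False,
                         weight_filler=dict(type='gaussian', std=0.01))
    # BatchNorm's blobs are accumulated statistics, not learned weights,
    # so their lr_mult is set to zero.
    bn = L.BatchNorm(conv, param=[dict(lr_mult=0)] * 3)
    return bn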

top: "ip1"
param {
lr_mult: 1
decay_mult: 250
Contributor

This decay_mult setting and the bias lr_mult setting are pretty weird... @ducha-aiki was there a reason for these?

@ducha-aiki
Contributor

@cdoersch now it works :)
@jeffdonahue thanks for the catch; it was an artifact of search-and-replace. I have cleaned it up (and also set bias_lr to zero for the batch-normalized net).
Corrected definitions and training logs are at https://gist.github.com/ducha-aiki/c0d1325f0cebe0b05c36

@cdoersch
Contributor Author

@jeffdonahue @ducha-aiki I've fixed the lr's and decays. Can you confirm that I've made all the required changes?

@ducha-aiki
Contributor

@cdoersch LGTM 👍

@jeffdonahue
Contributor

Thanks again @cdoersch and @ducha-aiki! LGTM as well.

jeffdonahue added a commit that referenced this pull request Oct 23, 2015
Yet another batch normalization PR
@jeffdonahue merged commit 39f69fb into BVLC:master on Oct 23, 2015
@siddharthm83

@ducha-aiki, that makes sense.
So I changed it as per the above, and I find that the EltwiseAffineLayer only works for me in CPU mode (the loss converges). In GPU mode, the loss stays constant at 87.3365 and does not decrease. In fact, this is the same number I got with two different net architectures (one was caffenet with batchnorm and eltwiseaffine). I added a batchnorm layer and an eltwiseaffine layer after all ReLUs. Interestingly enough, if I have only one batchnorm followed by eltwiseaffine, it works. So perhaps the other layers are compensating for whatever is going wrong.

I will post some sample snippets and the prototxt file on your PR page.

@happynear

@ducha-aiki

Adding BN without scale/shift before ReLU really hurts the expressive power of the model, since the output of BN is expected to have zero mean, which makes the bias term in the convolution layer meaningless.

However, when the BN layer is added after ReLU, there is no need to append a scale/shift layer after it, because scale/shift are also linear transformations (the following convolution can absorb them). In the Batch Normalization paper, they did not do any experiments to analyse where it is best to place the BN layer. They claimed that BN + ReLU can produce a stable distribution; I can't understand this. I am looking forward to your results.

@ducha-aiki
Contributor

@happynear @siddharthm83 @cdoersch @beniz @shelhamer
Here are the results of the tests:
https://github.com/ducha-aiki/batchnorm-benchmark
Not yet with EltwiseAffine, because each training run takes 48 hours.

@cuihenggang

Hi, I assume the BatchNormalization layer is pretty much done (is it?). I wonder whether anyone has tried training the ILSVRC12 task using the Inception-BN network? What validation accuracies have we gotten?

@cdoersch
Contributor Author

Ok, I chatted with some other BVLC folks and it sounds like we're going to go ahead and merge some kind of per-channel scale/shift layers. What PRs do we currently have that do this?

I'm currently aware of #2996, as well as #2079/#3021. I also vaguely remember someone referencing a PR with separate scaling and shifting layers, but I can't find it now, so maybe I'm imagining things.

Let's first try to come to some consensus about which one should be merged, and then I'll review it. Right now I think #2996 looks pretty straightforward.

@ducha-aiki
Contributor

@cdoersch I am for separating #2996 into bias and scale layers, but I am also OK with it as it is now (rebased onto current master). So should I start doing this?
@shelhamer @ronghanghu

@cdoersch
Contributor Author

I'm not sure I see the point in separating it into two layers. I think it would be better if there were just options in the prototxt to turn off the bias or scale and save the computation if desired. I think the general use case for this will involve both of them, and combining them should save memory.

I might pick a different name though -- EltwiseAffine makes it sound like something different is happening for each element, when it's really a per-channel operation. Maybe PerChannelAffine?
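Numerically, the per-channel operation under discussion is just the following (a sketch of my own; gamma and beta stand for the learned scale and shift, one value per channel):

import numpy as np

def channelwise_affine(x, gamma, beta):
    # x: (N, C, H, W); gamma, beta: (C,) learned per-channel scale and shift.
    return x * gamma[None, :, None, None] + beta[None, :, None, None]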

@ducha-aiki
Contributor

@cdoersch agreed, flags to turn them off/on are better than separation.
ChannelWiseAffine?

@cdoersch
Contributor Author

@ducha-aiki I guess to follow the 'Eltwise' pattern, it should be ChanwiseAffine. ChannelwiseAffine would probably be more clear though.

@ducha-aiki
Contributor

@cdoersch OK, then I will clean it up, add flags and rebase.

@jeffdonahue
Contributor

@cdoersch ScalarLayer (#3021) and BiasLayer (in the newly-created #3550) are what I've been using to learn the BN params. I'd be interested to know how the performance and memory use compares with the combined approach in #2996 from @ducha-aiki.

@ducha-aiki
Contributor

@jeffdonahue I can test both variants.

P.S. The caffenet128 training with the BN-EA layer has almost finished, and it looks like EA helps at least in the BN-before-nonlinearity setup. We will see if it helps for BN-after-ReLU, which performs much better.

@jeffdonahue
Contributor

Thanks @ducha-aiki, that would be great. For reference, this Python function (using NetSpec) should do in-place batch norm with ScalarLayer + BiasLayer:

from caffe import layers as L  # pycaffe NetSpec layer helpers

def batch_norm(x, with_params=True, lr_mult=1, in_place=True):
    # BatchNorm's three blobs (mean, variance, scale factor) are accumulated
    # statistics rather than learned weights, so their lr_mult is zero.
    param = [dict(lr_mult=0)] * 3
    out = L.BatchNorm(x, param=param, in_place=in_place)
    if with_params:
        # Learned per-channel scale (Scalar, #3021) and shift (Bias, #3550).
        param = [dict(lr_mult=lr_mult, decay_mult=0)]
        kwargs = dict(axis=1, num_axes=1, param=param, in_place=in_place)
        out = L.Scalar(out, **kwargs)
        out = L.Bias(out, **kwargs)
    return out
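
Used inside a NetSpec definition, the helper would be called roughly like this (n and n.conv1 are hypothetical, standing for an existing NetSpec and the output of a preceding convolution layer):

n.bn1 = batch_norm(n.conv1)             # BatchNorm + Scalar + Bias, all in place
n.relu1 = L.ReLU(n.bn1, in_place=True)  # nonlinearity applied after normalization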

@ducha-aiki
Contributor

@jeffdonahue
The memory consumption is identical to within 1 MB for BN-caffenet; I just checked nvidia-smi while training the networks.

Speed:
ChannelwiseAffine (CA):
I0114 18:22:07.118582 11780 caffe.cpp:358] ChannelwiseAffine1 forward: 0.3755 ms.
I0114 18:22:07.118588 11780 caffe.cpp:361] ChannelwiseAffine1 backward: 1.17972 ms.

Scalar + Bias (S+B):
I0114 18:23:03.240345 11875 caffe.cpp:358] Scalar1 forward: 0.352228 ms.
I0114 18:23:03.240351 11875 caffe.cpp:361] Scalar1 backward: 0.521535 ms.
I0114 18:23:03.240358 11875 caffe.cpp:358] Bias1 forward: 0.176595 ms.
I0114 18:23:03.240365 11875 caffe.cpp:361] Bias1 backward: 0.72729 ms.
Sum:
forward: 0.528823 ms
backward: 1.248825 ms

So my implementation is a bit faster.

@ducha-aiki
Contributor

BN+EA

@nian-liu

Hi everyone, I am not clear on what the param "moving_average_fraction" means and how to determine its value; could anyone give me a hint?

@classner

classner commented Feb 3, 2016

@nian-liu It is the weight with which the 'old' moving-average parts are down-weighted every iteration, i.e., (1 - moving_average_fraction) gives a measure of the speed of decay of the running mean (the higher it is, the faster the decay).

I just observed that this layer does not use the running mean and variance (with exponential decay) during training. This means that it becomes 'risky to use' (to say the least) with very small batch sizes, in the most extreme case with batch size one (which has the nice advantage for semantic segmentation that the network input can be dynamically resized). It can especially lead to large discrepancies between training and testing performance when the per-batch statistics do not approximate the global statistics well.

What was the reasoning behind this decision? It requires the assumption that the mean and variance over a batch are a reasonable estimate of the mean and variance of the entire dataset, which may or may not hold, and increasingly does not ;) for small batch sizes.
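For concreteness, the running-statistics update that moving_average_fraction controls amounts to the following (a sketch mirroring the axpby calls quoted earlier in this thread, with names of my own choosing):

def update_running_stats(running_mean, running_var, batch_mean, batch_var,
                         m, moving_average_fraction):
    # new = 1 * batch_stat + moving_average_fraction * old
    # (Caffe also accumulates a normalization factor in a third blob that is
    #  divided out at test time; omitted here for brevity.)
    new_mean = batch_mean + moving_average_fraction * running_mean
    new_var = batch_var * m / (m - 1) + moving_average_fraction * running_var
    return new_mean, new_var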

@cuihenggang

So do we have a new train_val protobuf for ImageNet that includes scale/shift operations after batch normalization?

@cuihenggang

Actually, while looking at the BatchNormLayer implementation, I found some issues.

There is a default option "use_global_stats_" that lets the BatchNormLayer store and use the moving average of the mean and variance values. I think in the original paper they normalize only inside the mini-batch, without considering previous samples, which makes the normalized mini-batch white (mean = 0 and variance = 1). I think with the use of a moving average of the mean and variance, the output of our BatchNormLayer won't be white, because those are not the real mean and variance of this mini-batch. Will this cause any problems?

@classner

Actually, by default it does not use the moving average during training; during testing it does. That may lead to exactly the problem I described above...
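In other words, the two code paths are roughly as follows (a sketch of the behaviour, not Caffe's actual implementation):

import numpy as np

def batch_norm_forward(x, running_mean, running_var, use_global_stats, eps=1e-5):
    # x: (N, C, H, W). use_global_stats defaults to false during training
    # (normalize with the batch's own statistics) and true at test time
    # (normalize with the accumulated running statistics).
    if use_global_stats:
        mean, var = running_mean, running_var
    else:
        mean = x.mean(axis=(0, 2, 3))
        var = x.var(axis=(0, 2, 3))
    return (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + eps)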

@cuihenggang

Actually, I have a question regarding the BatchNormLayer design. Are there any specific reasons why we chose to implement scale and bias in a separate ScaleLayer, rather than implementing them inside the BatchNormLayer? Aren't we consuming more memory by adding an extra ScaleLayer after each BatchNormLayer?

@d4nst

d4nst commented Jun 15, 2016

I am currently testing some nets with and without batch normalization, and I see that the memory consumption for nets with batch normalization is twice as high. For example, using this resnet I can train with a batch size of 24 images on my GPU; however, if I remove all the BatchNorm layers I can use up to 60 images or so. The problem seems to be in the BatchNorm backward pass.

What is the reason for this? It seems like very high memory consumption.
